#ifndef C_IMPLEMENT_GRAIN_GROWTH_H
#define C_IMPLEMENT_GRAIN_GROWTH_H

#include "nanovoid.h"

/**
 * Forward single-step update for the grain-growth model, implemented against
 * the generic OneStep interface (the five `override` methods below are
 * virtual in OneStep).
 *
 * NOTE(review): parameter meanings are inferred from names only — grid size
 * Nx x Ny, spacing h, energy coefficients A/B, mobility L, gradient
 * coefficient kappa, time step dtime, LSH settings lshK/lshL/lsh_r.
 * Confirm against the implementation file.
 */
class GrainGrowthOneStep : public OneStep
{
private:
    // Stencil offsets, defined out of line. Intended values:
    //   dx = {0, 1, 0, -1, 0}
    //   dy = {0, 0, 1,  0,-1}
    static const int dx[];
    static const int dy[];

    // Laplacian stencil length and weights, defined out of line.
    // Intended weights {-4, 1, 1, 1, 1}, aligned index-by-index with dx/dy.
    static const int laplen;
    static const valueType lapw[];

    // Grid extents and total site count (presumably Nx * Ny — confirm).
    int Nx;
    int Ny;
    int size;

    // Grain count and LSH configuration (inferred from names — confirm).
    uint n_grains;
    uint lshK;
    uint lshL;

    // Model and time-integration coefficients. h2 / dtimeL look like cached
    // derived values (e.g. h*h, dtime*L) — TODO confirm in the .cpp.
    valueType h;
    valueType h2;
    valueType A;
    valueType B;
    valueType updateL;
    valueType kappa;
    valueType dtime;
    valueType dtimeL;

public:
    GrainGrowthOneStep(int _Nx, int _Ny, uint _n_grains, uint _lshK, uint _lshL, valueType _h,
                       valueType _A, valueType _B, valueType _L, valueType _kappa,
                       valueType _dtime, valueType _lsh_r);

    // --- OneStep interface -------------------------------------------------
    // Collect the values site `c` needs from `value_table` into `vals`.
    void grab_vals(uint c, valueType *value_table, valueType *vals) override;
    // Advance site `c` one step from the gathered `vals`, writing to `new_v`.
    void forward_one_step(valueType *vals, uint c, valueType *new_v) override;
    // Transfer values for site `c_old` in `old_v` to site `c_new` in `new_v`.
    void assign_vals(valueType *old_v, uint c_old, valueType *new_v, uint c_new) override;

    // Neighbour-list maintenance for site `c` with respect to bucket `t`.
    void merge_neighbor_into_n_list(uint c, PNBucket *t) override;
    void move_out_neighbor_from_n_list(uint c, PNBucket *t) override;

    // --- Image conversion --------------------------------------------------
    // Load model state from a dense image buffer.
    void encode_from_img(valueType *img);
    // Produce a dense image buffer from the model state.
    // NOTE(review): ownership of the returned pointer is not visible here —
    // confirm whether the caller must free it.
    valueType *decode_to_img();
};




/**
 * Backward (gradient) counterpart of GrainGrowthOneStep, implemented against
 * the generic OneStep interface (the five `override` methods below are
 * virtual in OneStep). Alongside the model parameters it holds one derivative
 * accumulator per parameter, updated via accumulate_weight_derivative().
 *
 * NOTE(review): parameter meanings are inferred from names only — confirm
 * against the implementation file.
 */
class GrainGrowthOneBack : public OneStep
{
private:
    // Stencil offsets, defined out of line. Intended values:
    //   dx = {0, 1, 0, -1, 0}
    //   dy = {0, 0, 1,  0,-1}
    static const int dx[];
    static const int dy[];

    // Laplacian stencil length and weights, defined out of line.
    // Intended weights {-4, 1, 1, 1, 1}, aligned index-by-index with dx/dy.
    static const int laplen;
    static const valueType lapw[];

    // Length of the per-site value buffer. It has no in-class initializer, so
    // it is presumably set by the constructor — confirm in the .cpp.
    // NOTE(review): an earlier comment here ("= 5 * 2 * 2", "first 13 is for
    // cv, second 13 for ci, third 13 for eta, ...") was copied from the
    // nanovoid model. It does not describe this class: there are no cv/ci
    // fields here, and 5*2*2 = 20 is not six segments of 13. Document the real
    // layout once confirmed.
    uint vals_len;

    // Stencil sizes: 5 sites = centre plus 4 axis neighbours (radius 1);
    // 13 sites = all sites within Manhattan distance 2. The second is
    // presumably the footprint needed to differentiate through one
    // Laplacian-based step — confirm.
    static const uint lap_len_1st = 5;
    static const uint lap_len_2nd = 13;

    // Grid extents and total site count (presumably Nx * Ny — confirm).
    int Nx;
    int Ny;
    int size;

    // Grain count and LSH configuration (inferred from names — confirm).
    uint n_grains;
    uint lshK;
    uint lshL;

    // Model and time-integration coefficients (same set as GrainGrowthOneStep).
    valueType h;
    valueType h2;
    valueType A;
    valueType B;
    valueType updateL;
    valueType kappa;
    valueType dtime;
    valueType dtimeL;

    // Derivative accumulators, one per coefficient above (prefix `d`).
    // Presumably d(loss)/d(coefficient) — confirm in accumulate_weight_derivative().
    valueType dh;
    valueType dh2;
    valueType dA;
    valueType dB;
    valueType dupdateL;
    valueType dkappa;
    valueType ddtime;
    valueType ddtimeL;

public:
    GrainGrowthOneBack(int _Nx, int _Ny, uint _n_grains, uint _lshK, uint _lshL, valueType _h,
                       valueType _A, valueType _B, valueType _L, valueType _kappa,
                       valueType _dtime, valueType _lsh_r);

    // --- OneStep interface -------------------------------------------------
    // Collect the values site `c` needs from `value_table` into `vals`.
    void grab_vals(uint c, valueType *value_table, valueType *vals) override;
    // Advance site `c` one step from the gathered `vals`, writing to `new_v`.
    void forward_one_step(valueType *vals, uint c, valueType *new_v) override;
    // Transfer values for site `c_old` in `old_v` to site `c_new` in `new_v`.
    void assign_vals(valueType *old_v, uint c_old, valueType *new_v, uint c_new) override;

    // Neighbour-list maintenance for site `c` with respect to bucket `t`.
    void merge_neighbor_into_n_list(uint c, PNBucket *t) override;
    void move_out_neighbor_from_n_list(uint c, PNBucket *t) override;

    // --- Image conversion --------------------------------------------------
    // Load model state and the incoming loss gradient `dloss` from dense buffers.
    void encode_from_img(valueType *img, valueType *dloss);
    // Produce dense image buffers from the model state (an array of buffers).
    // NOTE(review): count and ownership of the returned buffers are not
    // visible here — confirm whether the caller must free them.
    valueType **decode_to_img();

    // --- Calculation helpers -----------------------------------------------
    // Compute the stepped values directly from `vals` into `new_v`, without a
    // site index.
    void forward_one_step_vals(valueType *vals, valueType *new_v);

    // --- Parameter gradients -----------------------------------------------
    // Add site `c`'s contribution (computed from `vals`) to the d* accumulators.
    void accumulate_weight_derivative(valueType *vals, uint c);

    // Print the accumulated derivatives (diagnostics).
    void print_derivative();

    // Return the accumulated derivatives as a buffer.
    // NOTE(review): length, order and ownership are not visible here — confirm.
    valueType *decode_derivative();
};

#endif // C_IMPLEMENT_GRAIN_GROWTH_H
